import datetime
import os
import time

from functools import wraps
from typing import Optional

import litellm

from tenacity import retry, stop_after_attempt, wait_fixed
from wandb.sdk.data_types.trace_tree import Trace

from cover_agent.custom_logger import CustomLogger
from cover_agent.record_replay_manager import RecordReplayManager
from cover_agent.settings.config_loader import get_settings
from cover_agent.utils import get_original_caller


def conditional_retry(func):
    """Decorator that retries *func* only when the instance has retries enabled.

    If ``self.enable_retry`` is falsy, the wrapped function is invoked exactly
    once. Otherwise the call is retried up to the configured ``model_retries``
    count (default 3) with a fixed one-second pause between attempts.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Fast path: retries disabled, call straight through.
        if not self.enable_retry:
            return func(self, *args, **kwargs)

        attempts = get_settings().get("default").get("model_retries", 3)

        # Build the tenacity-decorated thunk at call time so the attempt
        # count always reflects the current settings.
        @retry(stop=stop_after_attempt(attempts), wait=wait_fixed(1))
        def attempt_call():
            return func(self, *args, **kwargs)

        return attempt_call()

    return wrapper


class AICaller:
    """Thin wrapper around ``litellm.completion``.

    Handles prompt shaping (system/user message layout, o1-family quirks),
    optional streaming, optional retry (via ``conditional_retry``), optional
    W&B trace logging, and record/replay capture of responses.
    """

    def __init__(
        self,
        model: str,
        api_base: str = "",
        enable_retry: bool = True,
        max_tokens: int = 16384,  # TODO: Move to configuration.toml?
        source_file: Optional[str] = None,
        test_file: Optional[str] = None,
        record_mode: bool = False,
        record_replay_manager: Optional[RecordReplayManager] = None,
        logger: Optional[CustomLogger] = None,
        generate_log_files: bool = True,
    ):
        """
        Initializes an instance of the AICaller class.

        Parameters:
            model (str): The name of the model to be used.
            api_base (str): The base API URL to use in case the model is set to Ollama or Hugging Face.
            enable_retry (bool): Whether ``call_model`` retries failed LLM calls (see ``conditional_retry``).
            max_tokens (int): Token limit forwarded to the completion call.
            source_file (Optional[str]): Source file path used as the record/replay key, if any.
            test_file (Optional[str]): Test file path used as the record/replay key, if any.
            record_mode (bool): When True (and both file paths are set), responses are recorded.
            record_replay_manager (Optional[RecordReplayManager]): Injected manager; a default
                one is constructed here when omitted.
            logger (Optional[CustomLogger]): Injected logger; a default module logger is
                constructed here when omitted.
            generate_log_files (bool): Forwarded to the default manager/logger when they are
                created by this constructor.
        """
        self.model = model
        self.api_base = api_base
        self.enable_retry = enable_retry
        self.max_tokens = max_tokens
        self.source_file = source_file
        self.test_file = test_file
        self.record_mode = record_mode
        # Fall back to a locally-built manager/logger only when none is injected.
        self.record_replay_manager = record_replay_manager or RecordReplayManager(
            record_mode=record_mode, generate_log_files=generate_log_files
        )
        self.logger = logger or CustomLogger.get_logger(__name__, generate_log_files=generate_log_files)

    @conditional_retry  # You can access self.enable_retry here
    def call_model(self, prompt: dict, stream: bool = True) -> tuple:
        """
        Call the language model with the provided prompt and retrieve the response.

        Parameters:
            prompt (dict): Must contain 'system' and 'user' keys holding the prompt text.
                An empty 'system' value results in a single user message.
            stream (bool, optional): Whether to stream the response or not. Defaults to True.
                Forced off for the o1/o3 model families, which do not support streaming.

        Returns:
            tuple: A tuple containing the response generated by the language model, the number of tokens used from the prompt, and the total number of tokens in the response.

        Raises:
            KeyError: If the prompt dict is missing the 'system' or 'user' key.
            Exception: Re-raises any error from the underlying litellm call (and, when
                retries are enabled, streaming errors as well).
        """
        caller_name = get_original_caller()

        if "system" not in prompt or "user" not in prompt:
            raise KeyError("The prompt dictionary must contain 'system' and 'user' keys.")
        if prompt["system"] == "":
            messages = [{"role": "user", "content": prompt["user"]}]
        else:
            if self.model in ["o1-preview", "o1-mini"]:
                # o1 doesn't accept a system message so we add it to the prompt
                messages = [
                    {
                        "role": "user",
                        "content": prompt["system"] + "\n" + prompt["user"],
                    },
                ]
            else:
                messages = [
                    {"role": "system", "content": prompt["system"]},
                    {"role": "user", "content": prompt["user"]},
                ]

        # Default completion parameters
        completion_params = {
            "model": self.model,
            "messages": messages,
            "stream": stream,  # Use the stream parameter passed to the method
            "temperature": 0.2,  # TODO: Move to configuration.toml?
            "max_tokens": self.max_tokens,
        }

        # Model-specific adjustments
        if self.model in ["o1-preview", "o1-mini", "o1", "o3-mini"]:
            stream = False  # o1 doesn't support streaming
            completion_params["temperature"] = 1
            completion_params["stream"] = False  # o1 doesn't support streaming
            # Doubled limit — presumably to leave headroom for hidden reasoning
            # tokens on these models; TODO confirm against provider docs.
            completion_params["max_completion_tokens"] = 2 * self.max_tokens
            # completion_params["reasoning_effort"] = "high"
            completion_params.pop("max_tokens", None)  # Remove 'max_tokens' if present

        # API base exception for OpenAI Compatible, Ollama, and Hugging Face models
        if "ollama" in self.model or "huggingface" in self.model or self.model.startswith("openai/"):
            completion_params["api_base"] = self.api_base

        try:
            self.logger.info(f"📣 Calling LLM from {caller_name}()...")
            response = litellm.completion(**completion_params)
        except Exception as e:
            self.logger.error(f"Error calling LLM model: {e}")
            raise e

        if stream:
            chunks = []
            self.logger.info("Streaming results from LLM model...")
            try:
                for chunk in response:
                    # Echo each delta to stdout as it arrives.
                    print(chunk.choices[0].delta.content or "", end="", flush=True)
                    chunks.append(chunk)
                    # Optional: Delay to simulate more 'natural' response pacing
                    time.sleep(0.01)

            except Exception as e:
                self.logger.error(f"Error calling LLM model during streaming: {e}")
                # With retries disabled, swallow the streaming error and build a
                # best-effort response from whatever chunks already arrived.
                if self.enable_retry:
                    raise e
            model_response = litellm.stream_chunk_builder(chunks, messages=messages)
            print("\n")
            # Build the final response from the streamed chunks
            content = model_response["choices"][0]["message"]["content"]
            usage = model_response["usage"]
            prompt_tokens = int(usage["prompt_tokens"])
            completion_tokens = int(usage["completion_tokens"])
        else:
            # Non-streaming response is a CompletionResponse object
            content = response.choices[0].message.content
            self.logger.info("Printing results from LLM model...")
            print(content)
            usage = response.usage
            prompt_tokens = int(usage.prompt_tokens)
            completion_tokens = int(usage.completion_tokens)

        # Best-effort W&B trace logging, enabled solely by the presence of the API key.
        if "WANDB_API_KEY" in os.environ:
            try:
                root_span = Trace(
                    name="inference_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
                    kind="llm",  # kind can be "llm", "chain", "agent", or "tool"
                    inputs={
                        "user_prompt": prompt["user"],
                        "system_prompt": prompt["system"],
                    },
                    outputs={"model_response": content},
                )
                root_span.log(name="inference")
            except Exception as e:
                # Logging failures must never break the LLM call itself.
                self.logger.error(f"Error logging to W&B: {e}")

        # Record the exchange only when recording is on AND both file keys are set.
        if self.record_mode and self.source_file and self.test_file:
            self.record_replay_manager.record_response(
                self.source_file,
                self.test_file,
                prompt,
                content,
                prompt_tokens,
                completion_tokens,
                caller_name,
            )

        # Returns: Response, Prompt token count, and Completion token count
        return content, prompt_tokens, completion_tokens
